from otree.api import *
import random
import math

# App description. The treatment names listed below are the values of
# session.config['treatment_type'] that the pages in this module branch on.
doc = """
PGG with judges. 
Treatments: 
- no_judge: Standard PGG.
- AI_judge: No human judge; AI punishes based on formula.
- human_judge: Human judge punishes manually.
- humanAI_judge: AI suggests punishment, Human judge can edit.
"""

class C(BaseConstants):
    """Experiment-wide constants for the PGG-with-judges app."""

    # --- oTree framework settings ---
    NAME_IN_URL = 'base'
    # All players share a single oTree group; the actual PGG groups are
    # tracked manually in Player.pgg_group.
    PLAYERS_PER_GROUP = None
    NUM_ROUNDS = 3

    # --- Public-goods game parameters ---
    # Points a player starts each round with (kept minus contribution).
    ENDOWMENT = 20
    # Multiplier applied to the PGG group's total contribution in each
    # member's pre-punishment payoff.
    MPCR = 0.4

class Subsession(BaseSubsession):
    """Per-round subsession; no custom fields are needed."""

class Group(BaseGroup):
    """The single oTree group (PLAYERS_PER_GROUP is None); no custom fields."""

class Player(BasePlayer):
    """One participant in one round.

    Holds the player's role (judge or not), subgroup and PGG group
    assignment, the PGG decision and its outcome, the judge's punishment
    inputs, and the pre/post norm-elicitation answers.
    """

    # --- Role / grouping (set each round from participant.vars) ---
    is_judge = models.BooleanField(initial=False)
    subgroup = models.IntegerField(initial=0)
    pgg_group = models.IntegerField(initial=0)

    # --- PGG fields ---
    # Bounded by the endowment so the form limit follows C.ENDOWMENT.
    contribution = models.IntegerField(min=0, max=C.ENDOWMENT, initial=0)
    payoff_pre = models.FloatField(initial=0.0)

    # --- Punishment fields (received by the punished player) ---
    punishment_received = models.IntegerField(initial=0)
    group_punishment_cost = models.FloatField(initial=0.0)

    # --- Judge placeholder fields (max 12 normal players per subgroup) ---
    # p{i}_punish holds the punishment for the i-th non-judge player of the
    # judge's subgroup.
    p1_punish = models.IntegerField(min=0, max=5, initial=0)
    p2_punish = models.IntegerField(min=0, max=5, initial=0)
    p3_punish = models.IntegerField(min=0, max=5, initial=0)
    p4_punish = models.IntegerField(min=0, max=5, initial=0)
    p5_punish = models.IntegerField(min=0, max=5, initial=0)
    p6_punish = models.IntegerField(min=0, max=5, initial=0)
    p7_punish = models.IntegerField(min=0, max=5, initial=0)
    p8_punish = models.IntegerField(min=0, max=5, initial=0)
    p9_punish = models.IntegerField(min=0, max=5, initial=0)
    p10_punish = models.IntegerField(min=0, max=5, initial=0)
    p11_punish = models.IntegerField(min=0, max=5, initial=0)
    p12_punish = models.IntegerField(min=0, max=5, initial=0)

    # --- Norms pre fields ---
    # Personal norm: "What should one contribute?"
    np_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=C.ENDOWMENT)
    np_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=C.ENDOWMENT)
    np_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=C.ENDOWMENT)
    np_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=C.ENDOWMENT)
    np_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=C.ENDOWMENT)

    # Normative expectations: "What do others say one should contribute?"
    nn_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=C.ENDOWMENT)
    nn_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=C.ENDOWMENT)
    nn_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=C.ENDOWMENT)
    nn_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=C.ENDOWMENT)
    nn_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=C.ENDOWMENT)

    # Empirical expectations: "What will be the avg contribution?"
    expected_contribution = models.IntegerField(label="Avg contribution session:", min=0, max=C.ENDOWMENT)

    # --- Norms post fields (suffix _post; asked in the last round) ---
    np_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=C.ENDOWMENT)
    np_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=C.ENDOWMENT)
    np_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=C.ENDOWMENT)
    np_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=C.ENDOWMENT)
    np_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=C.ENDOWMENT)

    nn_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=C.ENDOWMENT)
    nn_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=C.ENDOWMENT)
    nn_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=C.ENDOWMENT)
    nn_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=C.ENDOWMENT)
    nn_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=C.ENDOWMENT)

    expected_contribution_post = models.IntegerField(label="Avg contribution session:", min=0, max=C.ENDOWMENT)


# =====================================================
# PAGES
# =====================================================

class Intro(Page):
    """Introductory page, shown only in the first round."""

    @staticmethod
    def is_displayed(player: Player):
        first_round = player.round_number == 1
        return first_round

    @staticmethod
    def vars_for_template(player: Player):
        config = player.session.config
        return {'treatment': config.get('treatment_type')}

def _split_sizes(n):
    """Return ``(a, b)``, the sizes of subgroups 1 and 2 for ``n`` normal players.

    Both sizes are positive multiples of 4, so each subgroup divides cleanly
    into PGG groups of four. They are as balanced as possible; on a tie, the
    smaller size goes to subgroup 1. ``n == 0`` yields ``(0, 0)``.

    Raises:
        ValueError: if ``n`` is non-zero and cannot be split that way.
    """
    if n == 0:
        return 0, 0
    if n % 4 != 0 or n < 8:
        raise ValueError(f"Cannot split {n} players into subgroups divisible by 4.")
    # n == 4*q with q >= 2: the most balanced split gives subgroup 1
    # q // 2 blocks of four (the smaller half when q is odd).
    a = 4 * ((n // 4) // 2)
    return a, n - a


def _assign_roles(players, treatment):
    """Round 1: store each participant's fixed role and subgroup in participant.vars."""
    for p in players:
        p.participant.vars['is_judge'] = False
        p.participant.vars['subgroup'] = 0

    if treatment not in ['human_judge', 'humanAI_judge']:
        # no_judge and AI_judge: everyone is a normal player in subgroup 1.
        for p in players:
            p.participant.vars['subgroup'] = 1
        return

    # Human-judge treatments: players 1 and 2 judge subgroups 1 and 2.
    for p in players:
        if p.id_in_subsession in (1, 2):
            p.participant.vars['is_judge'] = True
            p.participant.vars['subgroup'] = p.id_in_subsession

    normal_players = [p for p in players if not p.participant.vars['is_judge']]
    a, b = _split_sizes(len(normal_players))

    random.shuffle(normal_players)
    for p in normal_players[:a]:
        p.participant.vars['subgroup'] = 1
    for p in normal_players[a:a + b]:
        p.participant.vars['subgroup'] = 2


def _sync_round_fields(players):
    """Copy role/subgroup from participant.vars and reset per-round fields."""
    for p in players:
        p.is_judge = p.participant.vars.get('is_judge', False)
        p.subgroup = p.participant.vars.get('subgroup', 0)
        p.pgg_group = 0

        p.punishment_received = 0
        p.group_punishment_cost = 0.0
        if p.is_judge:
            for i in range(1, 13):
                setattr(p, f'p{i}_punish', 0)


def _reshuffle_pgg_groups(players):
    """Randomly regroup non-judges into PGG groups of 4 within each subgroup.

    PGG group ids are numbered consecutively across subgroups, starting at 1.
    The last group of a subgroup is smaller if its size is not a multiple of 4.
    """
    normal_players = [p for p in players if not p.is_judge]
    group_id = 1
    for sg in sorted({p.subgroup for p in normal_players}):
        sg_players = [p for p in normal_players if p.subgroup == sg]
        random.shuffle(sg_players)
        for start in range(0, len(sg_players), 4):
            for p in sg_players[start:start + 4]:
                p.pgg_group = group_id
            group_id += 1


class SetupWaitPage(WaitPage):
    """Assign roles (round 1) and form fresh PGG groups (every round)."""

    @staticmethod
    def after_all_players_arrive(group: Group):
        subsession = group.subsession
        players = subsession.get_players()
        treatment = subsession.session.config.get('treatment_type')

        if subsession.round_number == 1:
            _assign_roles(players, treatment)
        _sync_round_fields(players)
        _reshuffle_pgg_groups(players)

class Instructions(Page):
    """Instructions page, shown in round 1 only.

    The template receives the treatment, whether this player is a judge, and
    the PGG parameters.
    """

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def vars_for_template(player: Player):
        return dict(
            treatment=player.session.config.get('treatment_type'),
            is_judge=player.is_judge,
            endowment=C.ENDOWMENT,
            mpcr=C.MPCR,
        )

class NormsPersonal(Page):
    """Round-1 survey: personal norm ("what should one contribute?")."""

    form_model = 'player'
    form_fields = [f'np_{level}' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        return 1 == player.round_number

class NormsNormative(Page):
    """Round-1 survey: normative expectations about what others say one should contribute."""

    form_model = 'player'
    form_fields = [f'nn_{level}' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        return 1 == player.round_number

class NormsEmpirical(Page):
    """Round-1 survey: empirical expectation of the average contribution."""

    form_model = 'player'
    form_fields = ['expected_contribution']

    @staticmethod
    def is_displayed(player: Player):
        return 1 == player.round_number


class Cooperation(Page):
    """Contribution decision; judges skip this page."""

    form_model = 'player'
    form_fields = ['contribution']

    @staticmethod
    def is_displayed(player: Player):
        judge = player.is_judge
        return not judge

    @staticmethod
    def vars_for_template(player: Player):
        return {'endowment': C.ENDOWMENT, 'mpcr': C.MPCR}

class ResultsWaitPage(WaitPage):
    """Compute pre-punishment PGG payoffs and the judges' payoffs."""

    @staticmethod
    def after_all_players_arrive(group: Group):
        players = group.subsession.get_players()
        normal_players = [p for p in players if not p.is_judge]

        # 1. PGG outcome for each PGG group.
        for g_id in set(p.pgg_group for p in normal_players):
            members = [p for p in normal_players if p.pgg_group == g_id]
            total_contrib = sum(m.contribution for m in members)
            for m in members:
                m.payoff_pre = float(C.ENDOWMENT - m.contribution + (C.MPCR * total_contrib))
                # Provisional payoff; FinalWaitPage overwrites it after punishment.
                m.payoff = math.ceil(m.payoff_pre)

        # 2. Judge payoff: ceiling of the average pre-punishment payoff of
        #    the judge's subgroup. Each payoff is converted to float first;
        #    dividing an oTree Currency sum would round to the currency's
        #    decimal places and make the ceiling a no-op.
        for j in (p for p in players if p.is_judge):
            sg_payoffs = [float(p.payoff) for p in normal_players if p.subgroup == j.subgroup]
            if sg_payoffs:
                j.payoff = math.ceil(sum(sg_payoffs) / len(sg_payoffs))
            else:
                j.payoff = 0

def _ai_punishment_points(avg_contrib, contribution):
    """Return the AI punishment (0-5 points) for one player's contribution.

    Formula: min(5, max(0, round((avg_contrib - contribution) / 2))).
    Players at or above the average receive 0.
    """
    # NOTE(review): Python's round() rounds halves to even, e.g. a deviation
    # of exactly 5 gives round(2.5) == 2, not 3. Confirm this matches the
    # intended formula.
    return min(5, max(0, round((avg_contrib - contribution) / 2)))


class JudgeWaitPage(WaitPage):
    """Apply AI punishments (AI_judge) or pre-fill judge forms (humanAI_judge).

    no_judge and human_judge need nothing here.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        subsession = group.subsession
        treatment = subsession.session.config.get('treatment_type')

        if treatment == 'AI_judge':
            # No human judges exist: every player is a normal player.
            players = subsession.get_players()
            for sg in set(p.subgroup for p in players):
                sg_players = [p for p in players if p.subgroup == sg]
                avg_contrib = sum(p.contribution for p in sg_players) / len(sg_players)
                for p in sg_players:
                    p.punishment_received = _ai_punishment_points(avg_contrib, p.contribution)

        elif treatment == 'humanAI_judge':
            # Pre-fill each judge's p{i}_punish fields with the AI suggestion;
            # the judge can edit them on the Judge page.
            players = subsession.get_players()
            for judge in [p for p in players if p.is_judge]:
                sg_players = [p for p in players if not p.is_judge and p.subgroup == judge.subgroup]
                if not sg_players:
                    continue
                avg_contrib = sum(p.contribution for p in sg_players) / len(sg_players)
                # Only p1..p12 fields exist, so a subgroup may hold at most 12 players.
                for i, p in enumerate(sg_players, start=1):
                    setattr(judge, f'p{i}_punish', _ai_punishment_points(avg_contrib, p.contribution))

def _judge_subgroup_members(judge):
    """Return the non-judge players in ``judge``'s subgroup, in get_players() order.

    The 1-based position in this list decides which ``p{i}_punish`` field on
    the judge refers to which player. Every use on the Judge page goes
    through this helper so the mapping stays consistent.
    """
    return [
        p for p in judge.subsession.get_players()
        if not p.is_judge and p.subgroup == judge.subgroup
    ]


class Judge(Page):
    """Punishment page for human judges (human_judge / humanAI_judge)."""

    form_model = 'player'

    @staticmethod
    def is_displayed(player: Player):
        treatment = player.session.config.get('treatment_type')
        return player.is_judge and treatment in ['human_judge', 'humanAI_judge']

    @staticmethod
    def get_form_fields(player: Player):
        count = len(_judge_subgroup_members(player))
        return [f'p{i}_punish' for i in range(1, count + 1)]

    @staticmethod
    def vars_for_template(player: Player):
        # Group the form fields by PGG group for display.
        pgg_groups = {}
        for i, p in enumerate(_judge_subgroup_members(player), start=1):
            pgg_groups.setdefault(p.pgg_group, []).append({'player': p, 'field': f'p{i}_punish'})
        return dict(pgg_groups=pgg_groups)

    @staticmethod
    def before_next_page(player: Player, timeout_happened):
        # Copy the judge's decisions onto the punished players.
        for i, p in enumerate(_judge_subgroup_members(player), start=1):
            p.punishment_received = getattr(player, f'p{i}_punish')

class FinalWaitPage(WaitPage):
    """Apply punishments and set each normal player's final payoff.

    Each punished player loses their own punishment points. In addition,
    the PGG group's total punishment is shared equally by all of its members.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        players = group.subsession.get_players()
        normal_players = [p for p in players if not p.is_judge]

        for g_id in set(p.pgg_group for p in normal_players):
            members = [p for p in normal_players if p.pgg_group == g_id]

            total_punish = sum(m.punishment_received for m in members)
            # Divide by the real group size. Groups can have fewer than 4
            # members when the player count is not a multiple of 4; a fixed
            # /4 would then under-charge the group.
            shared_cost = total_punish / len(members)

            for m in members:
                m.group_punishment_cost = shared_cost
                final_val = float(m.payoff_pre - m.punishment_received - shared_cost)
                m.payoff = math.ceil(final_val)

class Results(Page):
    """Per-round results: anonymised contributions and punishments of the PGG group."""

    @staticmethod
    def vars_for_template(player: Player):
        if player.is_judge:
            return {}

        group_members = [
            other for other in player.subsession.get_players()
            if other.pgg_group == player.pgg_group and not other.is_judge
        ]

        # Members are labelled A, B, C, ... in get_players() order.
        anon_members = [
            {
                'label': chr(ord('A') + idx),
                'contribution': member.contribution,
                'punishment': member.punishment_received,
            }
            for idx, member in enumerate(group_members)
        ]

        return {
            'anon_members': anon_members,
            'total_group_punish': sum(member.punishment_received for member in group_members),
            'total_contribution': sum(member.contribution for member in group_members),
        }

class NormsPersonalPost(Page):
    """Last-round repeat of the personal-norm survey."""

    form_model = 'player'
    form_fields = [f'np_{level}_post' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round

class NormsNormativePost(Page):
    """Last-round repeat of the normative-expectations survey."""

    form_model = 'player'
    form_fields = [f'nn_{level}_post' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round

class NormsEmpiricalPost(Page):
    """Last-round repeat of the empirical-expectation question."""

    form_model = 'player'
    form_fields = ['expected_contribution_post']

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round

class SetupInfo(Page):
    """Overview page listing every player in the subsession."""

    @staticmethod
    def vars_for_template(player: Player):
        everyone = player.subsession.get_players()
        return {'players': everyone}

# Page order for every round. Round-1-only pages (Intro, Instructions, pre
# norms) and last-round-only pages (post norms) gate themselves through
# is_displayed. SetupInfo is defined in this module but not listed here.
page_sequence = [
    Intro,
    SetupWaitPage,  # roles (round 1) + PGG regrouping (every round)
    Instructions,
    # Pre Norms
    NormsPersonal,
    NormsNormative,
    NormsEmpirical,
    # PGG
    Cooperation,
    ResultsWaitPage,  # pre-punishment payoffs + judge payoffs
    JudgeWaitPage,  # AI punishment / AI suggestions for judges
    Judge,
    FinalWaitPage,  # apply punishments, final payoffs
    Results,
    # Post Norms
    NormsPersonalPost,
    NormsNormativePost,
    NormsEmpiricalPost,
]

